{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4b089493",
   "metadata": {},
   "source": [
    "# Simulated Environment: Gymnasium\n",
    "\n",
    "For many applications of LLM agents, the environment is real (the internet, a database, a REPL, etc.). However, we can also define agents that interact with simulated environments, such as text-based games. This notebook shows how to create a simple agent-environment interaction loop with [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) (formerly [OpenAI Gym](https://github.com/openai/gym))."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f36427cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "%pip install gymnasium"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f9bd38b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import inspect\n",
    "import tenacity\n",
    "\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.schema import (\n",
    "    AIMessage,\n",
    "    HumanMessage,\n",
    "    SystemMessage,\n",
    "    BaseMessage,\n",
    ")\n",
    "from langchain.output_parsers import RegexParser"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e222e811",
   "metadata": {},
   "source": [
    "## Define the agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "870c24bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GymnasiumAgent:\n",
    "    @classmethod\n",
    "    def get_docs(cls, env):\n",
    "        return env.unwrapped.__doc__\n",
    "\n",
    "    def __init__(self, model, env):\n",
    "        self.model = model\n",
    "        self.env = env\n",
    "        self.docs = self.get_docs(env)\n",
    "\n",
    "        self.instructions = \"\"\"\n",
    "Your goal is to maximize your return, i.e. the sum of the rewards you receive.\n",
    "I will give you an observation, reward, terminiation flag, truncation flag, and the return so far, formatted as:\n",
    "\n",
    "Observation: <observation>\n",
    "Reward: <reward>\n",
    "Termination: <termination>\n",
    "Truncation: <truncation>\n",
    "Return: <sum_of_rewards>\n",
    "\n",
    "You will respond with an action, formatted as:\n",
    "\n",
    "Action: <action>\n",
    "\n",
    "where you replace <action> with your actual action.\n",
    "Do nothing else but return the action.\n",
    "\"\"\"\n",
    "        self.action_parser = RegexParser(\n",
    "            regex=r\"Action: (.*)\", output_keys=[\"action\"], default_output_key=\"action\"\n",
    "        )\n",
    "\n",
    "        self.message_history = []\n",
    "        self.ret = 0\n",
    "\n",
    "    def random_action(self):\n",
    "        action = self.env.action_space.sample()\n",
    "        return action\n",
    "\n",
    "    def reset(self):\n",
    "        self.message_history = [\n",
    "            SystemMessage(content=self.docs),\n",
    "            SystemMessage(content=self.instructions),\n",
    "        ]\n",
    "\n",
    "    def observe(self, obs, rew=0, term=False, trunc=False, info=None):\n",
    "        self.ret += rew\n",
    "\n",
    "        obs_message = f\"\"\"\n",
    "Observation: {obs}\n",
    "Reward: {rew}\n",
    "Termination: {term}\n",
    "Truncation: {trunc}\n",
    "Return: {self.ret}\n",
    "        \"\"\"\n",
    "        self.message_history.append(HumanMessage(content=obs_message))\n",
    "        return obs_message\n",
    "\n",
    "    def _act(self):\n",
    "        act_message = self.model(self.message_history)\n",
    "        self.message_history.append(act_message)\n",
    "        action = int(self.action_parser.parse(act_message.content)[\"action\"])\n",
    "        return action\n",
    "\n",
    "    def act(self):\n",
    "        try:\n",
    "            for attempt in tenacity.Retrying(\n",
    "                stop=tenacity.stop_after_attempt(2),\n",
    "                wait=tenacity.wait_none(),  # No waiting time between retries\n",
    "                retry=tenacity.retry_if_exception_type(ValueError),\n",
    "                before_sleep=lambda retry_state: print(\n",
    "                    f\"ValueError occurred: {retry_state.outcome.exception()}, retrying...\"\n",
    "                ),\n",
    "            ):\n",
    "                with attempt:\n",
    "                    action = self._act()\n",
    "        except tenacity.RetryError as e:\n",
    "            action = self.random_action()\n",
    "        return action"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e76d22c",
   "metadata": {},
   "source": [
    "## Initialize the simulated environment and agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9e902cfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the Blackjack environment and an LLM agent that plays it.\n",
    "env = gym.make(\"Blackjack-v1\")\n",
    "agent = GymnasiumAgent(\n",
    "    env=env,\n",
    "    model=ChatOpenAI(temperature=0.2),  # low temperature -> less random replies\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2c12b15",
   "metadata": {},
   "source": [
    "## Main loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ad361210",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Observation: (15, 4, 0)\n",
      "Reward: 0\n",
      "Termination: False\n",
      "Truncation: False\n",
      "Return: 0\n",
      "        \n",
      "Action: 1\n",
      "\n",
      "Observation: (25, 4, 0)\n",
      "Reward: -1.0\n",
      "Termination: True\n",
      "Truncation: False\n",
      "Return: -1.0\n",
      "        \n",
      "break True False\n"
     ]
    }
   ],
   "source": [
    "# Play one episode: the agent acts until the environment terminates or truncates.\n",
    "try:\n",
    "    observation, info = env.reset()\n",
    "    agent.reset()\n",
    "\n",
    "    obs_message = agent.observe(observation)\n",
    "    print(obs_message)\n",
    "\n",
    "    while True:\n",
    "        action = agent.act()\n",
    "        observation, reward, termination, truncation, info = env.step(action)\n",
    "        obs_message = agent.observe(observation, reward, termination, truncation, info)\n",
    "        print(f\"Action: {action}\")\n",
    "        print(obs_message)\n",
    "\n",
    "        if termination or truncation:\n",
    "            print(\"break\", termination, truncation)\n",
    "            break\n",
    "finally:\n",
    "    # Release the environment even if the model call or env.step raises.\n",
    "    env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58a13e9c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
